{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Open Images v6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Open Images](https://storage.googleapis.com/openimages/web/download.html)是一个大型图像的数据集。这些数据有着图像级别的标注以及数千个类别的边界框。第六版数据集的物体检测子集包含600类物体的190万张图像。\n",
    "\n",
    "这份Notebook演示了如何从该数据集中提取人脸的边界框。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 准备工作\n",
    "首先载入一些数据分析必要的模块。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open Images处理图像外的数据存储为csv格式，pandas是处理csv文件的不二之选。\n",
    "import pandas as pd\n",
    "\n",
    "# 处理的过程中涉及到一些数学运算。\n",
    "import numpy as np\n",
    "\n",
    "# Matplotlib用来绘制图像与边界框。\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "# os包用来在不同系统下读取文件。\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于下载好的所有文件，我们采用如下文件存储约定：\n",
    "```\n",
    "-open_images_v4        数据集根目录\n",
    "|-annotation           所有csv文件\n",
    "|-train                训练用图像文件夹\n",
    "|-validation           验证用图像文件夹\n",
    "|-open_images.ipynb    本notebook\n",
    "```\n",
    "\n",
    "设定数据集根目录："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_dir = \"/home/robin/hdd/data/raw/open_images/v6\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "随机载入一张图片来测试数据集位置是否正确。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "val_img = plt.imread(os.path.join(root_dir,'train','7f111c25a72d31d6.jpg'))\n",
    "plt.imshow(val_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果一切顺利，一张五口之家的图像应当显示在上方。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据分析\n",
    "并非所有的Open Images数据都有边界框。物体检测这部分数据被单独分了出来，称为Subset with Bounding Boxes (600 classes)。这个数据子集的标注与metadata包括：\n",
    "\n",
    "- Boxes\n",
    "- Segmentations\n",
    "- Relationships\n",
    "- Localized narratives\n",
    "- Image labels\n",
    "- Image IDs\n",
    "- Metadata\n",
    "\n",
    "其中Boxes包含我们感兴趣的边界框。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Boxes顾名思义存储了边界框。载入训练boxes对应文件并查看其内容。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "boxes = pd.read_csv(os.path.join(root_dir, 'annotation/oidv6-train-annotations-bbox.csv'))\n",
    "print(\"Total records: {}\".format(boxes['Source'].count()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "提取前5项看看。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boxes.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "其中XMin, XMax, YMin, YMax是归一化的边界框坐标。同时有两个重要的参数后边会用到：\n",
    "\n",
    "- IsGroupOf: 如果边界框内同时包含了多个同类物体，则该项为1。\n",
    "- IsDepiction: 如果人脸为卡通人物或者绘画作品，该项为1。\n",
    "\n",
    "数据格式的完整描述可以在[官网](https://storage.googleapis.com/openimages/web/download.html)查到。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metadata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "载入Metadata并查看其内容。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "metadata = pd.read_csv(os.path.join(root_dir, 'annotation/class-descriptions-boxable.csv'), header=None)\n",
    "metadata.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个文件非常简单，是类别编码与类别名的映射。以上几个文件所包含的信息足够我们开始提取人脸的工作了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 提取人脸\n",
    "首先，我们需要找到\"Human face\"所对应的类别代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "face_label = metadata[metadata[1] == \"Human face\"].iat[0, 0]\n",
    "print(\"对应人脸Human face的类别代码是：{}\".format(face_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们找到属于该类别的所有图像。同时对图像进行初步的筛选：\n",
    "\n",
    "- 每个边界框内只包含一个人脸。\n",
    "- 必须是真人的人脸。（可选）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "is_face = boxes['LabelName'] == face_label\n",
    "is_individual = boxes['IsGroupOf'] == 0\n",
    "is_not_depiction = boxes['IsDepiction'] == 0\n",
    "face_anns = boxes[is_face & is_individual]\n",
    "print(\"筛选后获得数据总数：{}\".format(face_anns['ImageID'].count()))\n",
    "face_anns.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来随机选择一幅图像，并将标注绘制在图像上。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "img_id = np.random.choice(face_anns['ImageID'])\n",
    "print(\"Image id: {}\".format(img_id))\n",
    "img = plt.imread(os.path.join(root_dir, 'train', img_id+'.jpg'))\n",
    "\n",
    "# Be careful sometimes the image is of gray format that there is only one channel. As the neural networks most likely require a fixed input channel, it would be better to convert the image into 3 channel.\n",
    "img_height, img_width = img.shape[:2]\n",
    "\n",
    "# Try to draw the annotation.\n",
    "chosen_anns = face_anns[face_anns['ImageID'] == img_id]\n",
    "bboxes = chosen_anns.loc[:, ['XMin', 'XMax', 'YMin', 'YMax']].values\n",
    "currentAxis = plt.gca()\n",
    "for each_box in bboxes:\n",
    "    rect = Rectangle((each_box[0]*img_width, each_box[2]*img_height), \n",
    "                     (each_box[1] - each_box[0])*img_width, \n",
    "                     (each_box[3] - each_box[2])*img_height,\n",
    "                     linewidth = 2,\n",
    "                     edgecolor = 'c',\n",
    "                     facecolor='None'\n",
    "                    )\n",
    "    currentAxis.add_patch(rect)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "source": [
    "# 转换为TFRecord文件\n",
    "载入必要的工具包。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 进度条可以使得漫长的处理过程直观化\n",
    "from tqdm import tqdm\n",
    "\n",
    "# TensorFlow\n",
    "import tensorflow as tf\n",
    "\n",
    "# 制作TFRecord文件的工具\n",
    "from tf_record_generator import DetectionSample, create_tf_example"
   ]
  },
  {
   "source": [
    "设定需要写入TFRecord文件的数量。获取去重后的图像ID列表。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_train = 10000\n",
    "num_val = 1000\n",
    "\n",
    "image_ids = list(set(face_anns['ImageID']))\n",
    "\n",
    "print(\"图像总数为：{}，训练数据个数：{}，验证数据个数：{}\".format(len(image_ids), num_train, num_val))\n"
   ]
  },
  {
   "source": [
    "设定TFRecord文件的存储位置。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "record_file_train = \"/home/robin/data/face/oid6/train.record\"\n",
    "record_file_val = \"/home/robin/data/face/oid6/val.record\""
   ]
  },
  {
   "source": [
    "定义写入文件函数，供训练与验证共享使用。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_record(image_ids, num_samples, record_file):\n",
    "    num_valid_samples = 0\n",
    "    num_invalid_samples = 0\n",
    "    writer = tf.io.TFRecordWriter(record_file)\n",
    "\n",
    "    for image_id in tqdm(image_ids, total=num_samples):\n",
    "        # Break if enough samples written.\n",
    "        if num_valid_samples == num_samples:\n",
    "            break\n",
    "\n",
    "        # Get image file path.\n",
    "        image_file = os.path.join(root_dir, 'train', image_id+'.jpg')\n",
    "\n",
    "        # Get face annotations.\n",
    "        chosen_anns = face_anns[face_anns['ImageID'] == image_id]\n",
    "        bboxes = chosen_anns.loc[:, ['XMin', 'YMin', 'XMax', 'YMax']].values\n",
    "\n",
    "        # Creat detection example.\n",
    "        example = DetectionSample(image_file, bboxes)\n",
    "        tf_example = create_tf_example(example, min_size=None)\n",
    "        if tf_example is not None:\n",
    "            writer.write(tf_example.SerializeToString())\n",
    "            num_valid_samples += 1\n",
    "        else:\n",
    "            num_invalid_samples += 1\n",
    "    \n",
    "    # 收尾工作\n",
    "    writer.close()\n",
    "    print(\"文件写入完成：{}\".format(record_file))\n",
    "\n",
    "    return num_invalid_samples\n"
   ]
  },
  {
   "source": [
    "生成训练数据文件。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_invalid_samples = write_record(image_ids, num_train, record_file_train)"
   ]
  },
  {
   "source": [
    "生成验证用文件。"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = write_record(image_ids[num_train+num_invalid_samples:], num_val, record_file_val)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}